
# Start from an empty workspace so that objects left over from an earlier
# session cannot leak into this run.
rm(list=ls())
# foreach supplies foreach()/%dopar% used in the simulation loop below.
# NOTE(review): doParallel, doRNG and abind are loaded but not called in the
# visible part of this script -- confirm they are needed by the sourced files.
library(doParallel)
library(doRNG)
library(foreach)
library(abind)

# Install stop() as the session's error handler.
options(error=stop)
# Project helper scripts, resolved relative to the current working directory.
# NOTE(review): presumably these define MyAttach, DataGen, MLEst3, PMLEst4,
# GetContrast, expit, ... which are called further down -- confirm in the files.
source("2.0.0 MyFunc.R")
source("2.0.1 GetProb.R")
source("2.0.3 GetContrast.R")
source("2.1 DataGen.R")
source("2.1.2 MLEst3.R")
source("2.2.3 PMLEst4.R")



SEED.index = 1 # DEBUG: selects which block of `numsims` seeds this run uses

### Start of true values specification
#   Candidate sample sizes and number of simulated data sets
n.cand      = c(100, 200, 500, 1000)
numsims     = 500
#### Parameters that have stable properties
# alpha.true: 5 x 2 matrix, every row equal to (0, 0.7)
alpha.true  = matrix(rep(c(0, 0.7), 5), nrow = 5, ncol = 2, byrow = TRUE) #CHANGE
beta.3 = c(-0.5, 1)
beta.1 = beta.2 = beta.4 = beta.5 = c(-0.5, 0.1)
# beta.true: 5 x 2 matrix whose k-th row is beta.k
beta.true   = matrix(c(beta.1, beta.2, beta.3, beta.4, beta.5),
                     nrow = 5, ncol = 2, byrow = TRUE)


### Parameters to illustrate g-null paradox
# alpha.true  = t(matrix(rep(c(0,0),5),2,5))
# beta.3 =  c(-0.5, 0.1)
# beta.1 = c(-0.5, -1); beta.2 = c(-0.5, 1.5)
# beta.4 = c(-0.5, 1.5); beta.5 = c(-0.5, -2)
# beta.true   = t(matrix(c(beta.1, beta.2, beta.3, beta.4, beta.5),2,5))

# datagen = DataGen(1000)
# MyAttach(datagen)
# cnt

# gamma.true: 2 x 4 matrix, one row per treatment (a0, a1); the a0 row is
# padded with NA in its last two positions.
gamma.true  = matrix(c(0.1, -0.5, NA,  NA,
                       0.1, -0.5, 0.1, -0.5),
                     nrow = 2, ncol = 4, byrow = TRUE)
# Label the rows BEFORE the matrix is copied into params.true; labelling it
# afterwards would leave the stored (and later saved) copy without row names.
rownames(gamma.true) = c("a0","a1")
est.method.cand = c("MLE3","PMLE4","DR")
params.true = list(alpha.true=alpha.true, beta.true=beta.true,
                   gamma.true=gamma.true)

### Containers for the simulation output
# Labels for the simulation, sample-size and estimator dimensions
sim.name = paste0("sim=", seq_len(numsims))
n.name   = paste0("n=", n.cand)
est.method.name = as.character(est.method.cand)
# results[simulation, sample size, estimator, row, column, parameter, measure]:
# an NA-filled 7-way array whose 4th/5th dimensions follow the shape of
# alpha.true; the dimension names are attached at construction time.
results  = array(
    dim = c(numsims, length(n.cand), length(est.method.cand),
            nrow(alpha.true), ncol(alpha.true), 2, 4),
    dimnames = list("Simulation number" = sim.name,
                    "Sample size"       = n.name,
                    "est.method"        = est.method.name,
                    "1-5"               = seq_len(nrow(alpha.true)),
                    "1/2"               = seq_len(ncol(alpha.true)),
                    "Param"             = c("alpha", "beta"),
                    "Measure"           = c("est", "nll", "time", "contrast")))

# Suppress messages (This is a global variable!).
# Plain `<-` is required here: `MESSAGE <-- FALSE` parses as
# `MESSAGE <- (-FALSE)` and stores the integer 0 instead of logical FALSE.
MESSAGE <- FALSE

# Timing reference points: ptm0 is used for the whole-script report at the
# end, ptm1 for the per-sample-size report.
ptm0 = proc.time()
ptm1 = proc.time()

# Only one candidate sample size is run: the 4th entry of n.cand.
n.index = 4
n = n.cand[[n.index]]
cat("Sample size: ",n,"\n")

# Seed the generator with the sample size, then draw the pool of 10000
# per-simulation seeds that the simulation loop indexes into.
set.seed(n)
seeds = runif(n = 10000, min = -2^30, max = 2^30)
# Run-level analysis configuration, read by the simulation loop below
# (NOTE(review): presumably also read as globals by the sourced DR code --
# confirm). Logical switches use TRUE/FALSE rather than the reassignable
# T/F shorthands.
ee.assume.true.alpha = TRUE

# Nuisance model for y.hat.type for DR estimation
y.hat.type = "MLE"           # "saturated" or "MLE" or "logit" or "diy"
prop.score.type = "correct"  # "correct" or "incorrect"
vectorize.get.prob = FALSE
parallel.root = FALSE
l0.type = "discrete"         # "continuous" or "discrete"
l1.type = "correct"          # "correct" or "incorrect" or "diy"
gop.type = "incorrect"       # "correct" or "incorrect"

# Switches for the estimators / diagnostics run per simulation (1 = on, 0 = off)
doMLE     = 0
doPMLE    = 1
doDR      = 1
doGComp   = 1
doDebugEE = 0

# Load the doubly-robust estimation scripts, in this order, into the
# global environment.
invisible(lapply(c("2.5 CallDR.R",
                   "2.5.1 DR.R",
                   "2.5.1 DR_Iter.R",
                   "2.5.3 DR_MyFunc.R",
                   "2.6 DR_EE_value.R"),
                 source))

# G-computation estimates: one row per simulation and one column per
# (l0, a0, a1) combination, stored at column 4*l0 + 2*a0 + a1 + 1.
GComp.res = array(dim = c(numsims, 8))
# Estimating-equation diagnostics, written only when doDebugEE is on:
# 5 x 2 values per simulation flattened to 10 columns, plus two scalars.
ee.res = array(dim = c(numsims, 5*2))
ee.stage1 = ee.d1 = ee.d0 = array(dim = c(numsims, 5*2))
ee.u1 = ee.u0 = numeric(numsims)

# numsims = 500
# The simulations are processed in M consecutive batches of per.iter runs.
# The batch boundaries are used as array indices, so numsims must be an
# exact multiple of M.
M = 5
stopifnot(numsims %% M == 0)
per.iter = numsims/M

### 25 Jul 2021: run 10 times DR estimation.
# Same layout as one (sample size, estimator) slice of `results`, with a
# leading dimension for the 10 starting values. The slice is taken at
# n.index instead of a hard-coded position.
dr.res.10 = array(dim = c(10, dim(results[, n.index, "PMLE4", , , , ])))
dimnames(dr.res.10) = list("Initial"           = 1:10,
                           "Simulation Number" = sim.name,
                           "1-5"               = seq_len(nrow(alpha.true)),
                           "1/2"               = seq_len(ncol(alpha.true)),
                           "Param"             = c("alpha", "beta"),
                           "Measure"           = c("est", "nll", "time",
                                                   "contrast"))



t0.total = proc.time()
# Run the numsims simulations in M batches of per.iter. Each batch is sent
# through foreach/%dopar%; the per-simulation results come back as a list,
# are copied into `results` / `GComp.res`, and the state is checkpointed to
# disk after every batch.
# NOTE(review): no parallel back end registration (e.g. registerDoParallel())
# is visible in this file; without one %dopar% runs sequentially with a
# warning -- confirm it is registered in one of the sourced scripts.
for (iter in seq_len(M)){
    # iter = 1
    # Simulation indices handled by this batch
    boot.list = ((iter-1)*per.iter+1):(iter*per.iter)
    # for(i in boot.list){
    boot1 = foreach(i = boot.list,.errorhandling="stop") %dopar% {
        # i= 1 #DEBUG

        # Every simulation takes its own seed from the pre-generated pool.
        seed.perturb = 0
        seed.index = i + (SEED.index-1) * numsims +seed.perturb
        set.seed(seeds[seed.index])
        cat("\t Simulation number: ",seed.index-seed.perturb,"\n")

        ### 0: Data generation: x, y
        # Helper scripts are re-sourced inside the task.
        # NOTE(review): presumably so that worker processes have the
        # definitions -- confirm.
        source("2.1 DataGen.R")

        datagen = DataGen(n, SEED=seeds[seed.index])
        # NOTE(review): MyAttach presumably exposes the elements of `datagen`
        # (data, cnt, a0.cov, a0.vec, a1.cov, a1.vec, y.cov, y.vec, l1.vec)
        # as variables used below -- confirm in "2.0.0 MyFunc.R".
        MyAttach(datagen)

        ### 1 = Maximal likelihood estimate
        # Writes to `results` inside the task are not relied upon: the
        # returned list is copied into `results` after the batch.

        mle3 = NULL
        if (doMLE){
            set.seed(seeds[seed.index])
            t0 = proc.time()
            mle3 = MLEst3(data, cnt)
            t1 = proc.time()
            print(paste0("MLE time: ", (t1 - t0)[3], " secs")) # 403.982 seconds
            results[i,n.index,"MLE3",,,,] = mle3
        }

        pmle4 = NULL
        source("2.0.1 GetProb.R")
        if (doPMLE){
            set.seed(seeds[seed.index])
            t2 = proc.time()
            pmle4 = PMLEst4(data, cnt)
            results[i,n.index,"PMLE4",,,,] = pmle4
            t3 = proc.time()
            # 360.650
            print(paste0("PMLE time: ", round(as.numeric(t3 - t2)[3],2), " secs"))
        }
        # Point estimates: [, , 1, 1] is alpha/"est", [, , 2, 1] is beta/"est"
        # (both are NULL when doPMLE is off).
        alpha.pmle = pmle4[,,1,1]
        beta.pmle = pmle4[,,2,1]


        ### 2 = DR Estimate based on Vansteelandt and Joffe (2014)
        #   2.1 Estimate gamma propensity score model
        if (prop.score.type == "correct"){
            gamma.a0 = glm.fit(a0.cov,a0.vec,family=binomial(),intercept=FALSE)$coefficients
            gamma.a1 = glm.fit(a1.cov,a1.vec,family=binomial(),intercept=FALSE)$coefficients
        }else{
            # Misspecified propensity models: every covariate column from the
            # second onwards is zeroed out before fitting, and the matching
            # coefficients are then set to 0.
            a0.cov.incor = a0.cov; a1.cov.incor = a1.cov
            a0.cov.incor[,2] = 0; a1.cov.incor[,2:ncol(a1.cov.incor)] = 0
            gamma.a0 = glm.fit(a0.cov.incor,a0.vec,family=binomial(),intercept=FALSE)$coefficients
            gamma.a1 = glm.fit(a1.cov.incor,a1.vec,family=binomial(),intercept=FALSE)$coefficients
            gamma.a0[2] = 0
            gamma.a1[2:ncol(a1.cov.incor)] = 0
        }
        # Same 2 x 4 layout as gamma.true: the a0 row is padded with NA.
        gamma = t(matrix(c(gamma.a0,NA,NA,gamma.a1),4,2))
        rownames(gamma) = c("a0","a1")
        #   2.2 Estimate outcome regression for y
        source("2.5 CallDR.R")
        source("2.5.1 DR.R")
        source("2.5.1 DR_Iter.R")
        source("2.5.3 DR_MyFunc.R")
        source("2.0.0 MyFunc.R")
        source("2.0.1 GetProb.R")

        dr.sol = NULL
        if (doDR){
            # DR needs the PMLE as input; compute it here if it was skipped.
            if (!doPMLE){
                set.seed(seeds[seed.index])
                pmle4 = PMLEst4(data, cnt)
                results[i,n.index,"PMLE4",,,,] = pmle4
                # pmle4 = results[i,n.index,"PMLE4",,,,] #DEBUG
            }

            alpha.pmle = pmle4[,,1,1]
            beta.pmle = pmle4[,,2,1]

            if (gop.type == "incorrect"){
                # From here on `data` and `cnt` are the versions returned by
                # rm.l0.tilde().
                tmp = rm.l0.tilde(data, cnt)
                data = tmp$data
                cnt = tmp$cnt
                source("2.5.3 DR_MyFunc.R")
                nb = 32
                vec.1 = rep(1,nb)
                vec.0 = rep(0,nb)
            }

            if (y.hat.type == "saturated"){
                # Use saturated model
                y.hat = cnt[17:32]/(cnt[17:32]+cnt[1:16])
            }else if (y.hat.type == "MLE"){
                # Use the invert map
                DataGen.obj = func.DataGen(c(alpha.pmle,beta.pmle,gamma.true), data,cnt)
                y.hat = DataGen.obj$p.y[1:16]
            }else if (y.hat.type == "logit"){
                # Use logistic regression
                y.hat = glm.fit(y.cov,y.vec,family=binomial(),intercept=FALSE)$coefficients
            }else if (y.hat.type == "diy"){
                y.hat = c(0.5,0,0,0,0)
            }

            # NOTE(review): l1.hat is only assigned here when
            # l1.type == "incorrect"; otherwise DREst receives whatever
            # `l1.hat` is in scope (R evaluates the argument lazily) --
            # confirm DREst does not use it in that case.
            if (l1.type == "incorrect"){ # then build a logit reg of l1~l0 + a0
                l1.cov = y.cov[,c("l0.1","l0.2","a0")]
                l1.hat = glm.fit(l1.cov,l1.vec,family=binomial(),intercept=FALSE)$coefficients
            }

            #   2.3 Estimate alpha (beta and p0 are given by MLE)
            # j = 1; vals = list() # to output EE value
            set.seed(seeds[seed.index])
            t4 = proc.time()
            results[i,n.index,"DR",,,,] = dr.sol =
                DREst(data,cnt,alpha.pmle,beta.pmle,gamma, y.hat,l1.hat,
                      SEED = seeds[seed.index], dr.warm="MLE")
            t5 = proc.time()
            # Report the elapsed component only, in the same format as the
            # MLE/PMLE timings (pasting the whole proc_time difference
            # printed one string per component).
            print(paste0("DR time: ", round(as.numeric(t5 - t4)[3],2), " secs"))
            # plot(unlist(vals),type='l',
            #      main="EE when alpha,beta use pmle, with saturated OR",
            #      xlab='optim iteration', ylab="EE value")
            # abline(1,0,col="red")

            alpha.dr = dr.sol[, , 1, 1]
            beta.dr = dr.sol[, , 2, 1]
            # (contrast.dr = dr.sol[, , 1, 4])
            source("2.6 DR_EE_value.R")

            # for(j in 1:10){
            #     # j=2
            #     dr.warm = dr.warm.obj = NULL
            #     if (j == 1){
            #         dr.warm="DIY"
            #         dr.warm.obj = alpha.pmle
            #     }else{
            #         dr.warm = "random"
            #     }
            #     dr.res.10[j,i,,,,]= dr.est =DREst(data,cnt,alpha.pmle,beta.pmle,gamma, y.hat,l1.hat, 
            #                                         SEED = seeds[seed.index]+j, dr.warm=dr.warm, dr.warm.obj=dr.warm.obj)
            #     
            #     (alpha.dr = dr.est[, , 1, 1])
            #     (beta.dr = dr.est[, , 2, 1])
            #     ee.val = EE.value(data, cnt, alpha.dr,beta.dr, gamma,
            #                        Hessian=FALSE,y.hat)
            #     dr.res.10[j,i,,,,"nll"] = nll.j = sum(ee.val$tmp^2)
            #     if (nll.j < 0.005){ # Here point.dr.est.MLE$nll is the DR objective value
            #         # Copy and paste this result and stop
            #         if (seed.start < 10){
            #             for (k in (j+1):10){
            #                 dr.res.10[k,i,,,,]= dr.res.10[j,i,,,,]
            #             }
            #         }
            #         break
            #     }
            # }
        }

        ############
        GComp.res.i = GComp.res[i,]
        if (doGComp){
            ### Do g-computation for E[Y(a1,a2)|L0=l0]
            # alpha.pmle = pmle4[,,1,1]
            beta.pmle = pmle4[,,2,1]

            # G-computation always uses a logistic outcome model and a
            # logistic l1 ~ l0 + a0 model. The fits live in gcomp.* variables
            # so that the run-level settings (y.hat.type, l1.type) and the DR
            # nuisance fits (y.hat, l1.hat) are not overwritten: overwriting
            # them changed the DR inputs of later simulations whenever the
            # back end reuses the evaluation environment between tasks.
            gcomp.y.hat = glm.fit(y.cov,y.vec,family=binomial(),intercept=FALSE)$coefficients
            # Fitted P(y = 1 | l1, a0, l0, a1); l0 is either one covariate
            # vector or a matrix with one row per evaluation point.
            y.hat.func = function(l0,a0,l1,a1,y.hat){
                if (is.matrix(l0)){
                    res = c(expit(cbind(l1,a0,l0,a1) %*% y.hat))
                }else{
                    res = expit(c(l1,a0,l0,a1) %*% y.hat)
                }
                return(res)
            }

            gcomp.l1.cov = y.cov[,c("l0.1","l0.2","a0")]
            gcomp.l1.hat = glm.fit(gcomp.l1.cov,l1.vec,family=binomial(),intercept=FALSE)$coefficients
            # Fitted P(l1 = 1 | l0, a0); `beta` and `nb` are accepted but not
            # used by the body.
            p.l1.func= function(a0,l0,beta, nb,l1.hat="holder"){
                if (is.matrix(l0)){
                    # res = sapply(1:nb, function(i) {expit(c(l0[i,], a0[i]) %*% l1.hat)})
                    res = c(expit(cbind(l0, a0) %*% l1.hat))
                }else{
                    res = expit(c(l0, a0) %*% l1.hat)
                }
                return(res)
            }

            # Average the outcome model over l1 in {0, 1} for every
            # (l0, a0, a1) combination.
            for (l0 in c(0,1)){
                for (a0 in c(0,1) ){
                    for (a1 in c(0,1) ){
                        l0.vec = c(1,l0)
                        col.idx = 4*l0 + 2*a0 + a1 + 1
                        p.l1 = p.l1.func(a0,l0.vec,beta.pmle, nb,gcomp.l1.hat)
                        GComp.res[i,col.idx] = GComp.res.i[col.idx] =
                            y.hat.func(l0=l0.vec,a0,l1=1,a1,gcomp.y.hat) * p.l1 +
                            y.hat.func(l0=l0.vec,a0,l1=0,a1,gcomp.y.hat) * (1-p.l1)
                        # print(GComp.res.i[col.idx])
                    }
                }
            } # End of l0 loop
        } # End of doGComp
        if (doDebugEE){
            # Evaluate the estimating equations at the DR solution.
            # NOTE(review): the ee.* objects are not part of the returned
            # list, so they are not gathered after the batch.
            source("2.6 DR_EE_value.R")
            ee.val =
                EE.value(data, cnt, alpha.dr,beta.dr, gamma,
                         Hessian=FALSE,y.hat)
            ee.u1[i] = ee.val$tmp.u1
            ee.u0[i] = ee.val$tmp.u0
            for (r.idx in 1:5){
                for (c.idx in 1:2){
                    ee.res[i,2*(r.idx-1) + c.idx] = ee.val$tmp[r.idx, c.idx]
                    ee.stage1[i,2*(r.idx-1) + c.idx] = ee.val$tmp.stage1[r.idx, c.idx]
                    ee.d1[i,2*(r.idx-1) + c.idx] = ee.val$tmp.d1[r.idx, c.idx]
                    ee.d0[i,2*(r.idx-1) + c.idx] = ee.val$tmp.d0[r.idx, c.idx]

                }
            }
        }
        list(mle3=mle3, pmle4=pmle4, dr.sol=dr.sol, GComp.res.i=GComp.res.i)
    } # End of i = 1:numsims loop
    # Copy the batch results into the global containers; boot1 is ordered
    # like boot.list, so simulation i sits at position i - (iter-1)*per.iter.
    for (i in boot.list){
        tmp = boot1[[i-(iter-1)*per.iter]]
        if (doMLE){
            results[i,n.index,"MLE3",,,,] = tmp$mle3
        }
        if (doPMLE){
            results[i,n.index,"PMLE4",,,,] = tmp$pmle4
        }
        if (doDR){
            results[i,n.index,"DR",,,,]  = tmp$dr.sol
        }
        if (doGComp){
            GComp.res[i, ] = tmp$GComp.res.i
        }

    }
    # Checkpoint after every batch
    save(results, GComp.res, params.true, file=paste0("./RDataFiles/11Jun2021-parallel-n-1000-cts-l0-theta-not-0-l1-logit-iter-",iter,".RData"))
} # End iter
t1.total = proc.time()
t1.total - t0.total
#############
    # Unlist parallel result
    ##############
    
    
    
    # mean.res = results[1,n.index,"PMLE4",,,,]
    # n.index = 4
    
    # Summaries over the first curr.sim simulations.
    curr.sim = numsims
    # Running sums, started from simulation 1. seq_len(curr.sim)[-1] is empty
    # when curr.sim == 1, whereas 2:curr.sim would count down (2, 1).
    mean.res = GComp.res[1,]
    mean.res.pmle = results[1,n.index,"PMLE4",,,,]
    mean.res.DR = results[1,n.index,"DR",,,,]
    for(i in seq_len(curr.sim)[-1]){
      mean.res.pmle = mean.res.pmle +results[i,n.index,"PMLE4",,,,]
      mean.res.DR = mean.res.DR + results[i,n.index,"DR",,,,]
      mean.res = mean.res + GComp.res[i,]
    }

    # Monte Carlo means: alpha point estimates (PMLE, DR), then g-computation
    (mean.res.pmle/curr.sim)[, , 1, 1]
    (mean.res.DR/curr.sim)[, , 1, 1]
    mean.res/curr.sim

    # One column per "l0 a0 a1" combination; drop = FALSE keeps the matrix
    # shape when curr.sim == 1.
    wide = GComp.res[seq_len(curr.sim), , drop = FALSE]
    colnames(wide) = c("000","001","010","011",
                      "100","101","110","111")
    wide = data.frame(wide)
    library(data.table)
    # Stack the 8 columns: rows 1..4*curr.sim are the l0 = 0 columns, the
    # remaining 4*curr.sim rows are the l0 = 1 columns.
    long = melt(setDT(wide),measure.vars = 1:8, variable.name = "l0a0a1")
    # cond.g.null.aov = aov(value~l0a0a1, data=long[1:(4*curr.sim),])
    # summary(cond.g.null.aov)
    # Average the l0 = 0 and l0 = 1 halves with equal weight; the values are
    # taken as a plain numeric vector.
    marginal.data = data.frame(value=0.5*(long$value[1:(4*curr.sim)]+
                                             long$value[(1+4*curr.sim):(8*curr.sim)]),
                                a0a1 = rep(c("00","01","10","11"), each=curr.sim))
    marginal.g.null.aov = aov(value~a0a1, data=marginal.data) # because L0 is Bern(1/2)
    summary(marginal.g.null.aov)
    print(paste("This sample size takes",round((proc.time() - ptm1)[3],2),
                "s"))
    
# } # End for n.index in 1:length(n.cand))

# Compute the reference contrast on one large generated data set (n = 1e5,
# fixed seed) and store it with the true parameters.
# NOTE(review): `data` is presumably (re)defined by MyAttach(datagen) -- confirm.
rownames(gamma.true) = c("a0","a1")
datagen = DataGen(1e5, SEED=0)
MyAttach(datagen)
params.true[["contrast.true"]] = GetContrast(data, alpha.true, beta.true)
rownames(gamma.true) = c("a0","a1")

# Total elapsed time since ptm0 was taken
print(paste("It takes",round((proc.time() - ptm0)[["elapsed"]],2),"s"))


# Final save of all result containers, tagged with the seed-batch index
save(results, GComp.res, params.true, dr.res.10,
     file=paste0("./RDataFiles/8Jul2021-n-1000-theta-not-0-g-comp-y-misgop",
                 SEED.index,".RData"))

# Move to the local project folder and load a previously saved run.
# The file loaded here ("15Jul2021 ... theta-0 ...") is not the one written
# above; objects it stores under names already in the workspace replace them.
cluster.path = "~/Desktop/causal res/SNMM_Local/3. Data and Programming/cluster/080820"
setwd(dir = cluster.path)
load(file = paste0("./RDataFiles/15Jul2021-n-1000-theta-0-g-comp-y-misgop",1,".RData"))

